{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c80f55c",
   "metadata": {},
   "source": [
    "## 8.3 批量标准化"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "569f34f9",
   "metadata": {},
   "source": [
    "论文可见于：[论文链接](https://arxiv.org/abs/1502.03167)\n",
    "\n",
    "批量标准化 (Batch Normalization) 是一种在深度学习中很常见的技巧，可以通过对每个输入批次的数据进行标准化来提高神经网络的训练速度和准确性。它通过计算数据的均值和方差，将数据按比例缩放和平移使其落在指定的范围内。这样做可以缓解深层神经网络中的梯度消失和爆炸问题，通常被放置在卷积层或全连接层的输入或输出之间，有助于模型收敛。\n",
    "\n",
    "批量标准化 (Batch Normalization) 是2015年由两位谷歌工程师合作提出的，作者分别是Sergey Ioffe和Christian Szegedy。截止目前为止，论文引用次数已经超过4万。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3baaee4",
   "metadata": {},
   "source": [
    "### 8.3.1 批量标准化基本思想"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbb0b045",
   "metadata": {},
   "source": [
    "批量标准化是为了解决深度神经网络训练过程中的两个主要问题：\n",
    "\n",
    "* 梯度消失/爆炸：当神经网络深度很大时，梯度在传播到网络的较深层时会变得非常小或者非常大，导致梯度消失或者爆炸。这可能会导致模型无法收敛。\n",
    "\n",
    "* 内部协变量偏移 (Internal Covariate Shift)：这是指在训练过程中，神经网络的中间层输入分布可能会发生变化。这是因为随着训练的进行，前一层的权重可能会发生变化，导致中间层的输入分布发生变化。这可能会导致模型训练变慢，并且可能会影响模型的准确性。"
   ]
  },
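  {
   "cell_type": "markdown",
   "id": "e4b1a7c0",
   "metadata": {},
   "source": [
    "The vanishing-gradient problem described above can be observed in a small experiment. The following is a minimal sketch, not taken from the paper: the helper `make_mlp`, the depth of 20, the width of 64, and the sigmoid activations are all illustrative assumptions. It builds the same deep fully connected network with and without `nn.BatchNorm1d` and prints the gradient norm that reaches the first layer after one backward pass. The exact numbers depend on the seed and the initialization, so treat them as indicative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b1a7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "def make_mlp(depth, width, use_bn):\n",
    "    # Hypothetical helper: a stack of Linear (+ BatchNorm1d) + Sigmoid blocks\n",
    "    layers = []\n",
    "    for _ in range(depth):\n",
    "        layers.append(nn.Linear(width, width))\n",
    "        if use_bn:\n",
    "            layers.append(nn.BatchNorm1d(width))\n",
    "        layers.append(nn.Sigmoid())\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "x = torch.randn(128, 64)  # toy batch: 128 samples, 64 features\n",
    "\n",
    "for use_bn in (False, True):\n",
    "    net = make_mlp(depth=20, width=64, use_bn=use_bn)\n",
    "    net(x).sum().backward()  # dummy scalar loss, only used to obtain gradients\n",
    "    first_grad = net[0].weight.grad.norm().item()  # net[0] is the first Linear layer\n",
    "    print(f'use_bn={use_bn}: gradient norm at the first layer = {first_grad:.3e}')"
   ]
  },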
  {
   "cell_type": "markdown",
   "id": "2a65389b",
   "metadata": {},
   "source": [
    "### 8.3.2 批量标准化计算"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "916a086a",
   "metadata": {},
   "source": [
    "批量标准化 (Batch Normalization) 的计算其实很简单，经简化后，输出数据 y 的每个元素都是通过如下公式计算得到的：\n",
    "\n",
    "$$ y = \\gamma \\cdot \\frac{x - x_{mean}}{\\sqrt{x_{variance} + \\epsilon } } +\\beta  $$\n",
    "\n",
    "其中，$x$ 是输入数据，$ x_{mean}$ 和 $x_{variance}$ 分别是输入数据的均值和方差，$\\gamma$ 和 $\\beta$ 是批量标准化层的两个可训练参数，$\\epsilon$ 是一个很小的常数，用于避免方差为 0 的情况。\n",
    "\n",
    "批量标准化一般还有另外两个参数，移动平均值和移动方差，用于计算整个数据集的均值和方差。这个同学们了解即可，这里不做赘述。"
   ]
  },
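  {
   "cell_type": "markdown",
   "id": "b7d2f3a0",
   "metadata": {},
   "source": [
    "To make the formula concrete, the cell below is a minimal sketch that writes it out by hand and compares the result with `nn.BatchNorm1d` in training mode. The toy batch of 8 samples with 4 features is an arbitrary assumption. It also checks how the running mean and running variance are updated after one step, assuming PyTorch's default `momentum=0.1` and initial values of 0 and 1. Note that PyTorch normalizes with the biased batch variance but updates the running variance with the unbiased one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d2f3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "x = torch.randn(8, 4) * 3.0 + 2.0  # toy batch: 8 samples, 4 features\n",
    "\n",
    "bn = nn.BatchNorm1d(4)  # gamma (weight) starts at 1, beta (bias) at 0\n",
    "bn.train()\n",
    "y_ref = bn(x).detach()\n",
    "\n",
    "with torch.no_grad():\n",
    "    # The formula written out by hand: statistics are per feature, over the batch axis\n",
    "    mean = x.mean(dim=0)\n",
    "    var = x.var(dim=0, unbiased=False)  # biased variance is used for normalization\n",
    "    y_manual = bn.weight * (x - mean) / torch.sqrt(var + bn.eps) + bn.bias\n",
    "\n",
    "    # Running statistics after one training step with momentum 0.1\n",
    "    expected_running_mean = 0.9 * torch.zeros(4) + 0.1 * mean\n",
    "    expected_running_var = 0.9 * torch.ones(4) + 0.1 * x.var(dim=0, unbiased=True)\n",
    "\n",
    "print('output matches formula:', torch.allclose(y_ref, y_manual, atol=1e-5))\n",
    "print('running mean matches:  ', torch.allclose(bn.running_mean, expected_running_mean, atol=1e-5))\n",
    "print('running var matches:   ', torch.allclose(bn.running_var, expected_running_var, atol=1e-5))"
   ]
  },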
  {
   "cell_type": "markdown",
   "id": "03c6be7c",
   "metadata": {},
   "source": [
    "### 8.3.3 批量标准化代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9dff5e2",
   "metadata": {},
   "source": [
    "下面用批量标准化改造一下LeNet模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ecb9916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fe12ce0",
   "metadata": {},
   "source": [
    "重新定义LeNet网络，并增加四个BatchNorm层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "72e3dcde",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 5),\n",
    "            nn.BatchNorm2d(6)\n",
    "        )\n",
    "        self.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.BatchNorm2d(16)\n",
    "        )\n",
    "        self.fc1 = nn.Sequential(\n",
    "            nn.Linear(16 * 4 * 4, 120),\n",
    "            nn.BatchNorm1d(120)\n",
    "        )\n",
    "        self.fc2 = nn.Sequential(\n",
    "            nn.Linear(120, 84),\n",
    "            nn.BatchNorm1d(84)\n",
    "        )\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.conv1(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        x = torch.relu(self.conv2(x))\n",
    "        x = nn.functional.max_pool2d(x, 2)\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        x = torch.relu(self.fc1(x))\n",
    "        x = torch.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
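  {
   "cell_type": "markdown",
   "id": "c9a4e5b0",
   "metadata": {},
   "source": [
    "The convolutional blocks use `nn.BatchNorm2d` while the fully connected blocks use `nn.BatchNorm1d`. As a rough illustration of the difference, the sketch below checks that `nn.BatchNorm2d` computes one mean and one variance per channel, taken over the batch and both spatial dimensions. The input shape `(8, 6, 24, 24)` is an assumption chosen to resemble the output of the first convolution on a batch of 8 MNIST images; the data itself is random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9a4e5b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "x = torch.randn(8, 6, 24, 24) * 2.0 + 1.0  # (N, C, H, W) with 6 channels\n",
    "bn2d = nn.BatchNorm2d(6)\n",
    "bn2d.train()\n",
    "y = bn2d(x).detach()\n",
    "\n",
    "# One mean/variance per channel, computed over the N, H and W dimensions\n",
    "mean = x.mean(dim=(0, 2, 3), keepdim=True)\n",
    "var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)\n",
    "y_manual = (x - mean) / torch.sqrt(var + bn2d.eps)  # gamma = 1, beta = 0 at initialization\n",
    "\n",
    "print('matches per-channel formula:', torch.allclose(y, y_manual, atol=1e-5))\n",
    "print('per-channel mean after BN:', y.mean(dim=(0, 2, 3)))\n",
    "print('per-channel variance after BN:', y.var(dim=(0, 2, 3), unbiased=False))\n",
    "print('gamma/beta shapes:', bn2d.weight.shape, bn2d.bias.shape)  # one pair per channel"
   ]
  },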
  {
   "cell_type": "markdown",
   "id": "9ad16a80",
   "metadata": {},
   "source": [
    "替换LeNet结构重新进行训练，相比原始LeNet又有提升。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "96bf0385",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [02:32<00:00, 15.26s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAq/UlEQVR4nO3deXxU5dn/8c81k30hO4EsQCBhCTuERVE2NxZZilXBatmEqsTS9ulPrdr6PNY+2mr7VCsuFFFpVbQoCrigKAjKIgECshPWLCxhT4CQZe7fHzNACIEMMOFkJtf79ZpXztznnjMXB/hyc5b7iDEGpZRS3s9mdQFKKaU8QwNdKaV8hAa6Ukr5CA10pZTyERroSinlI/ys+uLY2FjTrFkzq75eKaW80qpVqw4aY+KqW2dZoDdr1oysrCyrvl4ppbySiOy+2Do95KKUUj5CA10ppXxEjYEuItNF5ICIrL/IehGRl0QkR0TWiUgXz5eplFKqJu6M0N8CBlxi/UAgzfWaCLx69WUppZS6XDUGujFmMXD4El2GATOM03IgUkQae6pApZRS7vHEMfREILfS+zxX2wVEZKKIZIlIVmFhoQe+Wiml1BnX9KSoMWaqMSbDGJMRF1ftZZRKKaWukCeuQ88Hkiu9T3K1KaXqOmPAOKq83G27jD4Y53oMnJ2xu3Kb8VAbNfQz5+qkcr2V11f5tVRtO7uNmvo4zu9Xua3VAEjs6vHfTk8E+hwgU0RmAj2AY8aYvR7YrqoPHA4wFeAoB0eFa7mi0nJN7Y4qfcqd23SUV3lVnFs+u61L9Kl2O5f4TOWazn5PlYA4L3Aqt1UXNJXbqul30W1c7LsuFrTq2hNo0NiaQBeR94C+QKyI5AFPAf4AxpjXgM+AQUAOcBIY6/EqVfUcFVBaDOWnoaIUKspcr1JwVFq+aHspVJSfW3aUX137eYHpuDA8jaNKAFZQabhWN4gNbH4gdudP25mffhd5b6vy3g/8g52fF3FuD3Eun/kpriOdldtwtbvVxiW2Kxf2t1Wq5aKvyp+voY8726naVl1d5+0Dqmmrrt+VtFH9Pj1TX9W2M38OLtmn8vvq2qruy8q/l7WnxkA3xoyqYb0BJnmsIl9mjDN8S4vhdJHrZ7Hr5/FKy8VQWuT8eV6/ovP7lJ2ohSIF7AGul9+5ZZtf9e0BYa5l/3OBdzYMbZwXjGJ3BaLdzXZb9ds8G7iuNrFVWr5UENcQzmJ3bl8pL2XZXC5eyRhn8B4vgOP5cOJgNcFcVE3bmTAuco5O3eEXDIFhzsAMDIOAcAiLh5hUV1v4uXV+QZXC1t/1OhOy/pcX0DZ77e5DpVSt8cpAP3qylMiQAM9u1Bg4dcQZ1GcC+3jBhculxdV/XmzO0K0cwoHhENbw/PA9E8bntYVfGN52r/ytUUpZyOtSY+ri7bz8TQ5LHu1PRLC/ex9yOODkwYuH9Jnl8pLzPyc2CG8MDRKgYRtIvdm53CABGiRCaNy5YPYPrvXjY0opdSleF+i9UmP53882869lu8jsn+Y8sVa8v0pIVw3rvc6TgZXZ/CDcFc6NO0GrQc6Qjkh0/myQAKENdaSslPIaXpdWbYuWMSv6NQIW78Wx5gS24v3OKyYqsweeG0Un9zy3XHV0rSfAlFI+xOsCneJ9tPPPI6simJzQ1rTs1KpKYCdCSLQe/lBK1TveF+hdxxDUdQyvTVvBlv1FLBnfjyB/vTJDKaW89phDZv9UCotO80FWbs2dlVKqHvDaQO+REk23ZlG8tmg7peV6C7NSSnltoIsImf3TKDhWwuw1eVaXo5RSlvPaQAfonRZLh6QIXlm0nfIKHaUrpeo3rw50ESGzXyq7D51k3jqd4FEpVb95daAD3NwmntaNwnl5YQ4ORx2buU8ppa4hrw90m02Y1C+VnAPFzN+wz+pylFLKMl4f6ACD2jemeWwo//gmB2N0lK6Uqp98ItDtNuGhfqls3HuchVsOWF2O
UkpZwicCHWBYpwSSooJ56WsdpSul6ie3Al1EBojIFhHJEZHHqlnfVES+FpF1IrJIRJI8X+ql+dttPNi3Bdm5R1m6/dC1/nqllLJcjYEuInZgCjAQSAdGiUh6lW4vADOMMR2Ap4FnPV2oO37aNYn4BoH845ttVny9UkpZyp0RencgxxizwxhTCswEhlXpkw5841peWM36ayLQz84verdg+Y7DrNx12IoSlFLKMu4EeiJQeQasPFdbZWuBEa7lnwDhIhJTdUMiMlFEskQkq7Cw8ErqrdGo7k2ICQ3g5W9yamX7SilVV3nqpOhvgT4isgboA+QDFVU7GWOmGmMyjDEZcXFxHvrq8wUH2Ln/xuZ8u7WQdXlHa+U7lFKqLnIn0POB5Ervk1xtZxljCowxI4wxnYEnXG1HPVXk5bq3ZxMigv11lK6UqlfcCfSVQJqIpIhIADASmFO5g4jEisiZbf0OmO7ZMi9PeJA/Y3s148uN+9m877iVpSil1DVTY6AbY8qBTGA+sAn4wBizQUSeFpGhrm59gS0ishWIB/5US/W6bcz1zQgL9GPKwu1Wl6KUUteEW4+gM8Z8BnxWpe0PlZZnAbM8W9rViQwJ4L7rmvLat9v51c1ptIgLs7okpZSqVT5zp2h1xt+QQqCfjVcX6ShdKeX7fDrQY8MCuad7U2avySf38Emry1FKqVrl04EOMLF3c+wivPqtjtKVUr7N5wO9UUQQd2YkMSsrj73HTlldjlJK1RqfD3SAB/q0oMIYpi7eYXUpSilVa+pFoCdHh/CTzom898MeCotOW12OUkrVinoR6AAP9W3B6XIHb3y30+pSlFKqVtSbQG8eF8btHRL417JdHD1ZanU5SinlcfUm0AEy+6VyorSCN7/fZXUpSinlcfUq0Fs1Cue2tvG8+f1OikrKrC5HKaU8ql4FOkBmvzSOl5Tzr+W7rS5FKaU8qt4FevukCPq2imPakp2cLC23uhyllPKYehfoAA/3T+XwiVLe+yG35s5KKeUl6mWgd20azXXNY5i6eDslZRc8WEkppbxSvQx0cI7S9x8/zaxVeVaXopRSHlFvA/26FjF0aRLJq4u2U1bhsLocpZS6avU20EWEh/unkX/0FB+vya/5A0opVce5FegiMkBEtohIjog8Vs36JiKyUETWiMg6ERnk+VI9r2+rONomNOCVRdupcBiry1FKqatSY6CLiB2YAgwE0oFRIpJepduTOJ812hnnQ6Rf8XShtcE5Sk9l58ETfPrjXqvLUUqpq+LOCL07kGOM2WGMKQVmAsOq9DFAA9dyBFDguRJr163pjUhrGMaUb3Jw6ChdKeXF3An0RKDyBdt5rrbK/hu4V0TycD5M+uHqNiQiE0UkS0SyCgsLr6Bcz7PZhMz+qWzZX8RXm/ZbXY5SSl0xT50UHQW8ZYxJAgYB/xKRC7ZtjJlqjMkwxmTExcV56Kuv3uD2jWkWE8LL3+RgjI7SlVLeyZ1AzweSK71PcrVVNh74AMAYswwIAmI9UeC14Ge38VDfVH7MP8a3W+vG/xyUUupyuRPoK4E0EUkRkQCcJz3nVOmzB7gJQETa4Ax0r0rG4Z0TSYwM5h86SldKeakaA90YUw5kAvOBTTivZtkgIk+LyFBXt/8CJojIWuA9YIzxslQM8LPxQJ/mrNp9hOU7DltdjlJKXTaxKnczMjJMVlaWJd99MSVlFdz4l4W0jA/jnft7Wl2OUkpdQERWGWMyqltXb+8UrU6Qv51f9G7O9zmHWLX7iNXlKKXUZdFAr+KeHk2ICvFnysIcq0tRSqnLooFeRUiAH/ff2JxvNh9gff4xq8tRSim3aaBX477rmhIe5KejdKWUV9FAr0aDIH/GXt+Mz9fvY+v+IqvLUUopt2igX8TYXimEBNh1lK6U8hoa6BcRFRrAfT2bMndtATsPnrC6HKWUqpEG+iWMvzEFf7uNVxfpKF0pVfdpoF9Cw/AgRnVvwker88k7ctLqcpRS6pI00GswsXdz
ROD1b3dYXYpSSl2SBnoNEiKD+WnXJN7PymX/8RKry1FKqYvSQHfDg31SqXAY/rlYR+lKqbpLA90NTWJCGNYxgXdW7OFQ8Wmry1FKqWppoLvpoX4tKCmvYPr3O60uRSmlqqWB7qbUhuEMateYt5fu5tjJMqvLUUqpC2igX4ZJ/VIpPl3O28t2WV2KUkpdQAP9MqQnNODmNg2Z/v1Oik+XW12OUkqdx61AF5EBIrJFRHJE5LFq1v+fiGS7XltF5KjHK60jJvVL5ejJMt5ZvtvqUpRS6jw1BrqI2IEpwEAgHRglIumV+xhjfm2M6WSM6QT8A/ioFmqtEzo3ieLGtFj+uWQHJWUVVpejlFJnuTNC7w7kGGN2GGNKgZnAsEv0H4XzQdE+6+H+aRwsLmXmD3usLkUppc5yJ9ATgdxK7/NcbRcQkaZACvDNRdZPFJEsEckqLCy83FrrjO4p0XRPiebFr7exvbDY6nKUUgrw/EnRkcAsY0y1xyKMMVONMRnGmIy4uDgPf/W19ec7OmAT4edv/MC+YzolgFLKeu4Eej6QXOl9kqutOiPx8cMtZ6TEhvL2uO4cO1XGz6ev4OjJUqtLUkrVc+4E+kogTURSRCQAZ2jPqdpJRFoDUcAyz5ZYd7VLjGDqfV3ZdfAk49/O4lSpniRVSlmnxkA3xpQDmcB8YBPwgTFmg4g8LSJDK3UdCcw0xpjaKbVuuj41lhdHdmL1niNMenc1ZRUOq0tSStVTYlX+ZmRkmKysLEu+uza8s2I3T8xez4jOibxwZ0dsNrG6JKWUDxKRVcaYjOrW+V3rYnzVz3o05XBxKX/9aivRoQE8MbgNIhrqSqlrRwPdgzL7p3LoRCnTvttJbHggD/RpYXVJSql6RAPdg0SEP9yezuETpTz3+WaiQwK4q1tyzR9USikP0ED3MJtNeOHOjhw5WcpjH60jKjSAW9LjrS5LKVUP6GyLtSDAz8Zr93alfVIkme+uZsWOQ1aXpJSqBzTQa0looB9vjulGUlQw98/IYmPBcatLUkr5OA30WhQdGsCM8T0IC/Rj9Js/sOfQSatLUkr5MA30WpYYGcyMcd0pq3Bw3/QVFBbpQ6aVUrVDA/0aSIsP543R3Thw/DSjp//A8RJ9JqlSyvM00K+Rrk2jeOXeLmzdX8TEGVn6cAyllMdpoF9D/Vo15IU7O7J8x2Emz1xDhaNeTXujlKplGujX2PDOifzh9nTmb9jPkx//SD2by0wpVYv0xiILjLshhUMnTjNl4XZiQgP57W2trC5JKeUDNNAt8ttbW3GouJSXF+YQHRrAuBtSrC5JKeXlNNAtIiI8M7wdR06W8vS8jUSHBjC8c7WPalVKKbfoMXQL+dltvDiyMz1Sovntf9ayaMsBq0tSSnkxDXSLBfnb+efoDFrGh/Pgv1ezZs8Rq0tSSnkptwJdRAaIyBYRyRGRxy7S5y4R2SgiG0TkXc+W6dsaBPnz9rjuNGwQyNi3VpJzoMjqkpRSXqjGQBcROzAFGAikA6NEJL1KnzTgd0AvY0xb4FeeL9W3xYUH8q9xPfC327jvjR8oOHrK6pKUUl7GnRF6dyDHGLPDGFMKzASGVekzAZhijDkCYIzRg8FXoElMCG+P7U5xSTn3vbGCIydKrS5JKeVF3An0RCC30vs8V1tlLYGWIvK9iCwXkQHVbUhEJopIlohkFRYWXlnFPi49oQHTRmeQe+QUY99ayYnT5VaXpJTyEp46KeoHpAF9gVHAP0UksmonY8xUY0yGMSYjLi7OQ1/te3o0j+HlUZ1Zl3eUB99ZTWm5w+qSlFJewJ1AzwcqPxgzydVWWR4wxxhTZozZCWzFGfDqCt3athHPjejA4q2F/PY/a3HovC9KqRq4E+grgTQRSRGRAGAkMKdKn49xjs4RkVich2B2eK7M+umubsk8OqA1c9YW8PS8jTrvi1Lqkmq8U9QYUy4imcB8wA5MN8ZsEJGngSxjzBzXultFZCNQAfw/Y4w+SNMDHujT
nEPFp5n23U5iwwLI7K//8VFKVc+tW/+NMZ8Bn1Vp+0OlZQP8xvVSHiQiPD6oDYdPlPLCl1uJDg3knh5NrC5LKVUH6VwuXsBmE/780w4cOVnKkx//SFSIPwPbN7a6LKVUHaO3/nsJf7uNKT/rQqfkSCbPzGbp9oNWl6SUqmM00L1ISIAf08d0o2lMCBNnrGJ9/jGrS1JK1SEa6F4mMiSAGeO7ExHsz+jpP7Dz4AmrS1JK1REa6F6ocUQwM8Z3xwD3vbGC/cdLrC5JKVUHaKB7qRZxYbw5phuHT5Qy4pWlrNZpd5Wq9zTQvVjH5Ejem9ATmw3uem0Zr3+7Xe8oVaoe00D3ch2TI5n38I3c2jaeZz/fzLi3V3Ko+LTVZSmlLKCB7gMigv2Zck8XnhnejqXbDzHopSUs36E36ipV32ig+wgR4d6eTfn4oV6EBvhxzz+X8+KCbVToIRil6g0NdB+TntCAuQ/fwLBOifzfgq3cO20FB/QqGKXqBQ10HxQa6Mff7urI8z/tQHbuUQa+uITFW/WBIkr5Og10HyUi3JmRzJzMXsSGBfLz6T/wly82U16hD8tQyldpoPu4tPhwPsnsxajuTXhl0XbunrqcfH0AtVI+SQO9Hgjyt/PsiPa8NKozW/YVMejFJXy1cb/VZSmlPEwDvR4Z2jGBeQ/fQHJ0MBNmZPH03I36vFKlfIgGej3TLDaUDx+8njHXN2P69zv56WtL2X1IJ/hSyhe4FegiMkBEtohIjog8Vs36MSJSKCLZrtf9ni9VeUqgn53/HtqW1+/ryq6DJ7j9pe+Yt67A6rKUUlepxkAXETswBRgIpAOjRCS9mq7vG2M6uV7TPFynqgW3tW3EZ5NvJC0+jMx31/D47B8pKauwuiyl1BVyZ4TeHcgxxuwwxpQCM4FhtVuWulaSokJ4/xfX8UCfFry7Yg/Dp3xPzoFiq8tSSl0BdwI9Ecit9D7P1VbVHSKyTkRmiUhydRsSkYkikiUiWYWFeqNLXeFvt/HYwNa8NbYbB4pOM+Qf3/Hhqjyry1JKXSZPnRSdCzQzxnQAvgLerq6TMWaqMSbDGJMRFxfnoa9WntK3VUM+n3wjHZIi+K//rOU3H2Rz4nS51WUppdzkTqDnA5VH3EmutrOMMYeMMWfmbJ0GdPVMeepai28QxLsTejL5pjRmr8lnyMvfsWnvcavLUkq5wZ1AXwmkiUiKiAQAI4E5lTuISONKb4cCmzxXorrW7Dbh17e05J37e1BcUs6wKd/zzordGKMzNypVl9UY6MaYciATmI8zqD8wxmwQkadFZKir2y9FZIOIrAV+CYyprYLVtXN9i1g+m3wjPZvH8MTs9WS+t4bjJWVWl6WUugixatSVkZFhsrKyLPludXkcDsPri3fwwpdbSIwM5uV7OtMhKdLqspSql0RklTEmo7p1eqeoqpHNJjzYtwUf/KIn5RUO7nh1KW98t1MPwShVx2igK7d1bRrNZ5NvpE/Lhvxx3kYmzFjF0ZOlVpellHLRQFeXJTIkgH/+vCtPDUnn260HGPTiErJ2Hba6LKUUGujqCogIY3ul8OGD1+Nnt3H31OVMWZiDQ59fqpSlNNDVFeuQFMm8X97AgHaNeH7+Fm748zf872eb+DHvmB5fV8oCepWLumrGGL5Yv49Zq/JYvK2QsgpDs5gQhnRMYEjHBFrGh1tdolI+41JXuWigK486erKU+Rv2MXftXpZuP4jDQKv4cIZ0bMztHRJoFhtqdYlKeTUNdGWJwqLTfL5+L3PXFrBy1xEAOiRFMKRDAoM7NCYhMtjiCpXyPhroynIFR0/x6bq9zF1XwLq8YwB0axbFkI4JDGzXmLjwQIsrVMo7aKCrOmXXwRPMW1fA3LV72bK/CJs4pxkY0rExA9o2JiLE3+oSlaqzNNBVnbVlX5Er3AvYdegk/nahd1ocQzomcHN6PGGBflaXqFSdooGu6jxjDOvzjzN3XQHz1hZQcKyEQD8bN7VpyJAOCfRr3ZAg
f7vVZSplOQ105VUcDsPqPUeYu7aAT3/cy8HiUkID7NzathFDOjbmhtQ4Avz0FgpVP2mgK69VXuFgxc7DzF1bwOfr93HsVBkRwf4MbNeIIR0T6Nk8BrtNrC5TqWtGA135hNJyB9/lFDJ37V6+3LCPE6UVxIYFMri9M9y7NInCpuGufJwGuvI5JWUVLNx8gLnrCvh60wFOlztIiAji9o4J3NezKcnRIVaXqFStuOpAF5EBwIuAHZhmjHnuIv3uAGYB3Ywxl0xrDXTlKcWny1mwcT9z1xaweFshDgNDOybwYN8WOu2A8jlXFegiYge2ArcAeTifMTrKGLOxSr9w4FMgAMjUQFdW2H+8hGlLdvDOij2cLK3g1vR4JvVLpWNypNWlKeURV/vEou5AjjFmhzGmFJgJDKum3x+BPwMlV1ypUlcpvkEQTwxO5/tH+zP5pjRW7DzMsCnfc++0FSzdflBngVQ+zZ1ATwRyK73Pc7WdJSJdgGRjzKcerE2pKxYVGsCvb2nJ94/15/FBrdmyv4h7/rmCEa8uZcHG/Rrsyidd9cW8ImID/gb8lxt9J4pIlohkFRYWXu1XK1WjsEA/JvZuwZJH+vHM8HYUFp3m/hlZDHxxCZ9k51Ne4bC6RKU8xp1j6NcB/22Muc31/ncAxphnXe8jgO1AsesjjYDDwNBLHUfXY+jKCmUVDuatK+CVhdvZdqCYpjEhPNCnBSO6JBLop3eiqrrvak+K+uE8KXoTkI/zpOg9xpgNF+m/CPitnhRVdZnDYfhq036mLMxhXd4x4hsEMuHG5tzTowkhATp/jKq7ruqkqDGmHMgE5gObgA+MMRtE5GkRGerZUpW6Nmw24ba2jfhkUi/+Pb4HzWPDeObTTfR67hte+nobx06WWV2iUpdNbyxSymXV7iO8uiiHBZsOEBpg597rmjL+hhQahgdZXZpSZ+mdokpdhk17j/Pqou3MW1eAn93G3RnJTOzdXO8+VXWCBrpSV2DXwRO8vng7s1bl4TAwrFMCD/VtQWpDvftUWUcDXamrsPfYKaYt2cm7K/ZQUl7BbemNeKhfCzokRVpdmqqHNNCV8oDDJ0p56/udvLV0F8dLyrkxLZZJ/VLpkRKNiM7yqK4NDXSlPKiopIx3Vuxh2pKdHCw+TdemUUzq14J+rRpqsKtap4GuVC0oKavgP1m5vPbtDvKPnqJ1o3Am9UtlUPvG+tANVWs00JWqRWUVDuZkF/DKohy2F56gWUwIv+jTgts7NCY8yN/q8pSP0UBX6hpwOAxfbtzHlIXb+TH/GIF+Nm5Oj2d4p0T6tNTnoCrPuFSg6z3OSnmIzSYMaNeY29o2YvWeI3y8xvmQ60/X7SUi2J9B7RszvFMC3ZpF66PyVK3QEbpStaiswsF32w7ycXY+X27Yz6myChIighjSKYFhHRNp0zhcT6Sqy6KHXJSqA06WlvPVxv18kl3A4q2FlDsMLePDGNYpkaEdE/ROVOUWDXSl6phDxaf57Me9fJJdQNbuIwBkNI1iWKcEBndIIDo0wOIKVV3lNYFeVlZGXl4eJSX6FLsrERQURFJSEv7+emWFN8k9fJI5awv4JDufrfuL8bMJvVvGMaxTArekx+t0vuo8XhPoO3fuJDw8nJiYGD2ueJmMMRw6dIiioiJSUlKsLkddAWMMm/YW8cnafOZkF7D3WAnB/nZubeu8UuaGtFj87XqlTH3nNVe5lJSU0KxZMw3zKyAixMTEoI/2814iQnpCA9ITGvDoba1ZueswH2cXnD00Ex0awOD2jRneOYEuTaL074m6QJ0KdED/kF4F3Xe+w2YTejSPoUfzGP5naFu+3VrIx9n5fJCVy7+W7yYpKphhnRIY3imRtHid/VE51blAV0qdL8DPxi3p8dySHk/x6XLmr9/Hx9n5vLpoO1MWbqdN4wYM75TAkI4JJEQGW12uspBbgS4iA4AXATswzRjzXJX1DwCTgAqcD4ueaIzZ6OFar4mwsDCKi4tr7qiUBcIC
/bijaxJ3dE2isOg089YV8El2Ac9+vpnnvthM92bRDO+cyMB2jYgM0Stl6ht3HhJtx/mQ6FuAPJwPiR5VObBFpIEx5rhreSjwkDFmwKW2W91J0U2bNtGmTZsr+XV4jLcHel3Yh+ra23XwBJ9kO6+U2XHwBP52oW+rhgxq34iMptEkRQXrITkfcbUnRbsDOcaYHa6NzQSGAWcD/UyYu4QCV33pzP/M3cDGguM1d7wM6QkNeGpIW7f6GmN45JFH+PzzzxERnnzySe6++2727t3L3XffzfHjxykvL+fVV1/l+uuvZ/z48WRlZSEijBs3jl//+tcerV2pS2kWG8rkm9P45U2prM8/zifZ+cxZW8BXG/cDEBceSOfkSDo3iaJLk0jaJ0Xo5ZA+yJ3f0UQgt9L7PKBH1U4iMgn4DRAA9K9uQyIyEZgI0KRJk8ut9Zr66KOPyM7OZu3atRw8eJBu3brRu3dv3n33XW677TaeeOIJKioqOHnyJNnZ2eTn57N+/XoAjh49am3xqt4SEdonRdA+KYLfDWrDpr3HWZN7lDW7j7Am9yhfugLebhNaNwqnS5MoOjeJpEuTKJrGhOgo3st57J9oY8wUYIqI3AM8CYyups9UYCo4D7lcanvujqRry3fffceoUaOw2+3Ex8fTp08fVq5cSbdu3Rg3bhxlZWUMHz6cTp060bx5c3bs2MHDDz/M4MGDufXWWy2tXSlwhna7xAjaJUZwX8+mgPOpS9m5R1i9+yhrco/w0eo8/rV8NwDRoQGuUbwz4DskRxIWqKN4b+LO71Y+kFzpfZKr7WJmAq9eTVF1We/evVm8eDGffvopY8aM4Te/+Q0///nPWbt2LfPnz+e1117jgw8+YPr06VaXqtQFokMD6N86nv6t4wGocBi2HShizZ6jrHaN4r/efAAAm0DL+HA6VxrFN48N1Zki6zB3An0lkCYiKTiDfCRwT+UOIpJmjNnmejsY2IaXu/HGG3n99dcZPXo0hw8fZvHixTz//PPs3r2bpKQkJkyYwOnTp1m9ejWDBg0iICCAO+64g1atWnHvvfdaXb5SbnEeemlA60YNGNXdeRj02MkysvPOBfy8dQW898MeACKC/enkGsV3bhJFp+RIIoJ1qom6osZAN8aUi0gmMB/nZYvTjTEbRORpIMsYMwfIFJGbgTLgCNUcbvE2P/nJT1i2bBkdO3ZERPjLX/5Co0aNePvtt3n++efx9/cnLCyMGTNmkJ+fz9ixY3E4HAA8++yzFlev1JWLCPGnT8s4+rSMA5wP7thxsJjVe46yZs8R1uw5yotfb+PMBXKpDcPo0uTMCdcoUhuG6SP4LFKn5nLRS+6unu5DdS0UlZSxLu/Y2VH8mj1HOHKyDHBeK39uFB9J5+QoonT2SI/xmrlclFLeITzIn16psfRKjQWcl/nuOnSSNXuOsNo1in9l0XYqHM4BY7vEBtyVkcywjolEhOghmtqiga6UumoiQkpsKCmxoYzokgQ4H+ixLu8Yq3Yf4bMf9/KHTzbwp083MaBdI+7OSKZn8xg9wephGuhKqVoREuBHz+Yx9Gwew6R+qazPP8YHWbnMXpPPJ9kFNIkO4a6MJH7aNZlGEUFWl+sTNNCVUtfEmWviHx/Uhi/W7+P9lbm88OVW/vbVVvq0jOPubsnc1CZe53y/ChroSqlrKsjfzvDOiQzvnMjuQyf4T1Ye/1mVywP/Xk1sWAAjuiRxV0YyqQ3DrC7V62igK6Us0zQmlN/e1opf3ZzG4m2FvL8yl+nf7WTq4h10bRrF3RnJDO7QmFC9Y9UtupeUUpbzs9vO3sFaWHSaj1bn8X5WLo98uI7/mbuBIR0TuKtbMp2TI3W+mUvQQLdIeXk5fn66+5WqKi48kF/0acHE3s1ZtfsI76/M5ZPsAmauzCWtYRh3d0vmJ50TiQkLtLrUOqfuJsrnj8G+Hz27zUbtYeBzNXYbPnw4ubm5lJSUMHnyZCZOnMgXX3zB448/TkVFBbGxsXz9
9dcUFxfz8MMPn50296mnnuKOO+44b071WbNmMW/ePN566y3GjBlDUFAQa9asoVevXowcOZLJkydTUlJCcHAwb775Jq1ataKiooJHH32UL774ApvNxoQJE2jbti0vvfQSH3/8MQBfffUVr7zyCrNnz/bsPlKqjhARMppFk9EsmqeGtmXeWmeoP/PpJv78xWZubhPPXd2S6Z0Wp3emutTdQLfQ9OnTiY6O5tSpU3Tr1o1hw4YxYcIEFi9eTEpKCocPHwbgj3/8IxEREfz4o/MfniNHjtS47by8PJYuXYrdbuf48eMsWbIEPz8/FixYwOOPP86HH37I1KlT2bVrF9nZ2fj5+XH48GGioqJ46KGHKCwsJC4ujjfffJNx48bV6n5Qqq4IC/RjZPcmjOzehK37i3h/pfPyx8/X76NxRBB3dk3izoxkkqNDrC7VUnU30N0YSdeWl1566ezINzc3l6lTp9K7d29SUlIAiI6OBmDBggXMnDnz7OeioqJq3Padd96J3W4H4NixY4wePZpt27YhIpSVlZ3d7gMPPHD2kMyZ77vvvvv497//zdixY1m2bBkzZszw0K9YKe/RMj6c39+ezqMDWrNg035mrszlHwtzeOmbHHqlxnBXRjK3tW1EkL/d6lKvubob6BZZtGgRCxYsYNmyZYSEhNC3b186derE5s2b3d5G5ZM2JSUl560LDQ09u/z73/+efv36MXv2bHbt2kXfvn0vud2xY8cyZMgQgoKCuPPOO/UYvKrXAvxsDGrfmEHtG5N/9BSzsvL4ICuXyTOziQj25yedE7krI5n0hAZWl3rNaCJUcezYMaKioggJCWHz5s0sX76ckpISFi9ezM6dO88ecomOjuaWW25hypQp/P3vfwech1yioqKIj49n06ZNtGrVitmzZxMeHn7R70pMTATgrbfeOtt+yy238Prrr9OvX7+zh1yio6NJSEggISGBZ555hgULFtT2rlDKayRGBjP55jQe7p/K0u2HmLlyD++u2MNbS3fRPjGCu7olM7RjwmVN9WuMwWGg3OHA4Tj/Z4UxVDjOvcodBofr59n2Kn0q923dKLxWDg9poFcxYMAAXnvtNdq0aUOrVq3o2bMncXFxTJ06lREjRuBwOGjYsCFfffUVTz75JJMmTaJdu3bY7XaeeuopRowYwXPPPcftt99OXFwcGRkZF33o9COPPMLo0aN55plnGDx48Nn2+++/n61bt9KhQwf8/f2ZMGECmZmZAPzsZz+jsLBQZ1RUqho2m3BDWiw3pMVy5EQpH2fn8/7KXH7/8Xr+9OlGEiKCqTCG8gqDw5wfxOcFsiuMa8szw9txr+spUp6k0+d6mczMTDp37sz48eOrXa/7UKnzGWP4Mf8YH63O52Dxaew2cb5E8LMLNhH8bILdZsNu4/yfF/Sp5nXZfWwkRAZd8WWXOn2uj+jatSuhoaH89a9/tboUpbyGiNAhKZIOSZFWl1LrNNC9yKpVq6wuQSlVh7k1rZmIDBCRLSKSIyKPVbP+NyKyUUTWicjXInLFB4esOgTkC3TfKVW/1RjoImIHpgADgXRglIikV+m2BsgwxnQAZgF/uZJigoKCOHTokAbTFTDGcOjQIYKCdF5ppeordw65dAdyjDE7AERkJjAM2HimgzFmYaX+y4Ereux9UlISeXl5FBYWXsnH672goCCSkpKsLkMpZRF3Aj0RyK30Pg/ocYn+44HPq1shIhOBiQBNmjS5YL2/v//ZuzGVUkpdHo8+GkRE7gUygOerW2+MmWqMyTDGZMTFxXnyq5VSqt5zZ4SeDyRXep/kajuPiNwMPAH0Mcac9kx5Siml3OXOCH0lkCYiKSISAIwE5lTuICKdgdeBocaYA54vUymlVE3culNURAYBfwfswHRjzJ9E5GkgyxgzR0QWAO2Bva6P7DHGDK1hm4XA7iusOxY4eIWf9UW6P86n++Mc3Rfn84X90dQYU+0xa8tu/b8aIpJ1sVtf6yPdH+fT/XGO7ovz
+fr+8OhJUaWUUtbRQFdKKR/hrYE+1eoC6hjdH+fT/XGO7ovz+fT+8Mpj6EoppS7krSN0pZRSVWigK6WUj/C6QK9pKt/6QkSSRWSha9riDSIy2eqa6gIRsYvIGhGZZ3UtVhORSBGZJSKbRWSTiFxndU1WEZFfu/6erBeR90TEJ6cl9apAd3Mq3/qiHPgvY0w60BOYVI/3RWWTgU1WF1FHvAh8YYxpDXSknu4XEUkEfolziu92OG+QHGltVbXDqwKdSlP5GmNKgTNT+dY7xpi9xpjVruUinH9ZE62tyloikgQMBqZZXYvVRCQC6A28AWCMKTXGHLW0KGv5AcEi4geEAAUW11MrvC3Qq5vKt16HGICINAM6AyssLsVqfwceARwW11EXpACFwJuuQ1DTRCTU6qKsYIzJB14A9uCcnuSYMeZLa6uqHd4W6KoKEQkDPgR+ZYw5bnU9VhGR24EDxhh98KqTH9AFeNUY0xk4AdTLc04iEoXzf/IpQAIQ6prq2+d4W6C7NZVvfSEi/jjD/B1jzEdW12OxXsBQEdmF81BcfxH5t7UlWSoPyDPGnPlf2yycAV8f3QzsNMYUGmPKgI+A6y2uqVZ4W6DXOJVvfSEigvP46CZjzN+srsdqxpjfGWOSjDHNcP65+MYY45OjMHcYY/YBuSLSytV0E5UeG1nP7AF6ikiI6+/NTfjoCWJ3HnBRZxhjykUkE5jPual8N1hcllV6AfcBP4pItqvtcWPMZ9aVpOqYh4F3XIOfHcBYi+uxhDFmhYjMAlbjvDpsDT46BYDe+q+UUj7C2w65KKWUuggNdKWU8hEa6Eop5SM00JVSykdooCullI/QQFdKKR+hga6UUj7i/wOs672n42Ni/gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.9885\n"
     ]
    }
   ],
   "source": [
    "# 导入必要的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "from tqdm import *\n",
    "import numpy as np\n",
    "\n",
    "# 设置随机种子\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# 定义模型、优化器、损失函数\n",
    "model = LeNet()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.02)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# 设置数据变换和数据加载器\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),  # 将数据转换为张量\n",
    "    transforms.Normalize((0.5,), (0.5,))  # 对数据进行归一化\n",
    "])\n",
    "train_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)  # 加载训练数据\n",
    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)  # 实例化训练数据加载器\n",
    "test_dataset = datasets.MNIST(root='../data', train=False, download=True, transform=transform)  # 加载测试数据\n",
    "test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)  # 实例化测试数据加载器\n",
    "\n",
    "# 设置epoch数并开始训练\n",
    "num_epochs = 10  # 设置epoch数\n",
    "loss_history = []  # 创建损失历史记录列表\n",
    "acc_history = []   # 创建准确率历史记录列表\n",
    "\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    # 批量训练\n",
    "    model.train()\n",
    "    for inputs, labels in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # 测试模型\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        total_test_loss = 0\n",
    "        total_test_correct = 0\n",
    "        for inputs, labels in test_loader:\n",
    "            outputs = model(inputs)\n",
    "            total_test_loss += criterion(outputs, labels).item()\n",
    "            total_test_correct += (outputs.argmax(1) == labels).sum().item()\n",
    "        \n",
    "    # 输出测试集上的损失和准确率\n",
    "    loss_history.append(np.log10(total_test_loss))  # 将损失加入损失历史记录列表，由于数值较大，这里取对数\n",
    "    acc_history.append(total_test_correct / len(test_dataset))# 将准确率加入准确率历史记录列表\n",
    "\n",
    "# 使用图表库（例如Matplotlib）绘制损失和准确率的曲线图\n",
    "import matplotlib.pyplot as plt\n",
    "plt.plot(loss_history, label='loss')\n",
    "plt.plot(acc_history, label='accuracy')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# 输出准确率\n",
    "print(\"Accuracy:\", acc_history[-1])"
   ]
  },
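  {
   "cell_type": "markdown",
   "id": "d3f6a8c0",
   "metadata": {},
   "source": [
    "The training loop switches between `model.train()` and `model.eval()`, and with batch normalization this switch changes the computation itself. The sketch below illustrates this on a tiny stand-in network (a single `nn.Linear` followed by `nn.BatchNorm1d`, chosen only for illustration, not the LeNet above). In training mode the output for a sample depends on the other samples in its mini-batch; in evaluation mode the stored running statistics are used, so the same sample gives the same output whether it is processed alone or as part of a batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f6a8c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "net = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))  # tiny illustrative network\n",
    "batch = torch.randn(16, 4)\n",
    "sample = batch[:1]  # the first sample, kept as a batch of size 1\n",
    "\n",
    "# Training mode: statistics come from the current mini-batch,\n",
    "# and the running statistics are updated as a side effect\n",
    "net.train()\n",
    "out_train = net(batch)[:1].detach()\n",
    "\n",
    "# Evaluation mode: the stored running statistics are used instead\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    out_eval_in_batch = net(batch)[:1]\n",
    "    out_eval_alone = net(sample)\n",
    "\n",
    "print('eval: same output alone and in a batch:', torch.allclose(out_eval_in_batch, out_eval_alone, atol=1e-6))\n",
    "print('train output equals eval output:', torch.allclose(out_train, out_eval_alone, atol=1e-6))"
   ]
  },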
  {
   "cell_type": "markdown",
   "id": "cef9ed7b",
   "metadata": {},
   "source": [
    "### 8.3.4 小结"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d2fbad8",
   "metadata": {},
   "source": [
    "* 通过对输入和中间网络层的输出进行标准化处理后，减少了内部神经元分布的改变，使降低了不同样本间值域的差异性，得大部分的数据都其处在非饱和区域，从而保证了梯度能够很好的回传，避免了梯度消失和梯度爆炸。\n",
    "* 通过减少梯度对参数或其初始值尺度的依赖性，使得我们可以使用较大的学习速率对网络进行训练，从而加速网络的收敛。\n",
    "* 由于在训练的过程中批量标准化所用到的均值和方差是在一小批样本(mini-batch)上计算的，而不是在整个数据集上，所以均值和方差会有一些小噪声产生，同时缩放过程由于用到了含噪声的标准化后的值，所以也会有一点噪声产生，这迫使后面的神经元单元不过分依赖前面的神经元单元。所以，它也可以看作是一种正则化手段，提高了网络的泛化能力，使得我们可以减少或者取消 Dropout，从而优化网络结构。"
   ]
  }
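  ,
  {
   "cell_type": "markdown",
   "id": "f1c8b2d0",
   "metadata": {},
   "source": [
    "The noise mentioned in the last point can be made visible with a few lines of code. This is a minimal sketch on synthetic data, where the dataset size of 12800 and the batch size of 128 are arbitrary assumptions: the mean and variance of each mini-batch scatter around the statistics of the full dataset, and this scatter is the noise that batch normalization injects during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c8b2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "data = torch.randn(12800)      # synthetic dataset drawn from a standard normal\n",
    "batches = data.view(-1, 128)   # 100 mini-batches of 128 samples each\n",
    "\n",
    "batch_means = batches.mean(dim=1)\n",
    "batch_vars = batches.var(dim=1, unbiased=False)\n",
    "\n",
    "print('dataset mean / variance:', data.mean().item(), data.var(unbiased=False).item())\n",
    "print('std of mini-batch means:', batch_means.std().item())\n",
    "print('std of mini-batch variances:', batch_vars.std().item())"
   ]
  }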
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
